import torch
from .writer import GGUFWriter, GGMLQuantizationType
from .quant import quantize, QuantError
from .const import GGML_QUANT_VERSION, LlamaFileType
from safetensors.torch import load_file
from tqdm import tqdm
import numpy as np
# Float32/bfloat16 source tensors that are 1-D, or that hold at most this many
# elements, are stored as F32 instead of being down-converted.
QUANTIZATION_THRESHOLD = 1_024
# Longest tensor name (in characters) this converter is meant to accept.
# NOTE(review): presumably mirrors the tensor-name limit of the GGUF loader
# being targeted — confirm against that loader.
MAX_TENSOR_NAME_LENGTH = 127
def load_state_dict(path):
    """Read every tensor from the safetensors file at *path*.

    Returns:
        A new plain ``dict`` mapping tensor names to the tensors returned by
        ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return dict(tensors.items())
def load_model(path, model_arch):
    """Load the checkpoint at *path* and create an unopened GGUF writer.

    Args:
        path: Path of the ``.safetensors`` file to read.
        model_arch: Architecture name handed to ``GGUFWriter``. The script
            below passes ``None`` when the user declines to name the model.

    Returns:
        A ``(writer, state_dict, model_arch)`` tuple; *model_arch* is passed
        through unchanged.
    """
    tensors = load_state_dict(path)
    gguf_writer = GGUFWriter(path=None, arch=model_arch)
    return gguf_writer, tensors, model_arch
def validate_tensor_data(data, key):
    """Return *data* if all of its elements are finite, otherwise ``None``.

    A rejected tensor (one containing any NaN or +/-Inf) is reported with a
    warning that names *key*, and ``None`` is returned so the caller can skip
    it. Integer arrays and empty arrays always pass.

    Args:
        data: Array-like tensor contents (a numpy array in this script).
        key: Tensor name, used only in the warning message.
    """
    # isfinite is False exactly for NaN and +/-Inf, so one pass covers both.
    if np.isfinite(data).all():
        return data
    print(f"Warning: Tensor '{key}' contains NaN or Inf values. Skipping.")
    return None
def _tensor_to_numpy(tensor):
    """Convert *tensor* to numpy, widening dtypes numpy cannot represent.

    bfloat16 is widened to float32 and the float8 variants to float16; every
    other dtype is converted directly with ``Tensor.numpy()``.
    """
    if tensor.dtype == torch.bfloat16:
        return tensor.to(torch.float32).numpy()
    # The float8 dtypes exist only in newer torch releases, so look them up
    # defensively rather than referencing the attributes directly.
    float8_dtypes = tuple(
        getattr(torch, name)
        for name in ('float8_e4m3fn', 'float8_e5m2')
        if hasattr(torch, name)
    )
    if tensor.dtype in float8_dtypes:
        return tensor.to(torch.float16).numpy()
    return tensor.numpy()


def _select_qtype(old_dtype, data):
    """Choose the GGML storage type for a tensor whose source dtype is *old_dtype*.

    bfloat16 sources map to BF16 and every other dtype to F16. The exception
    is float32/bfloat16 tensors that are 1-D or hold at most
    ``QUANTIZATION_THRESHOLD`` elements, which are kept at F32.
    """
    if old_dtype in (torch.float32, torch.bfloat16) and (
        data.ndim == 1 or data.size <= QUANTIZATION_THRESHOLD
    ):
        return GGMLQuantizationType.F32
    if old_dtype == torch.bfloat16:
        return GGMLQuantizationType.BF16
    return GGMLQuantizationType.F16


def handle_tensors(writer, state_dict):
    """Convert each tensor in *state_dict* and register it with *writer*.

    For every tensor this:

    * converts it to numpy, widening bfloat16/float8 first;
    * skips it (with a warning) if it contains NaN or Inf values;
    * picks a storage type (see ``_select_qtype``) and quantizes to it,
      falling back to F16 if that quantization fails;
    * logs the conversion and hands the result to ``writer.add_tensor``.

    Args:
        writer: The GGUF writer that receives the converted tensors.
        state_dict: Mapping of tensor name to torch tensor.

    Raises:
        ValueError: If any tensor name is longer than
            ``MAX_TENSOR_NAME_LENGTH`` characters. This is checked before any
            tensor is converted, so nothing is added to *writer* in that case.
    """
    if not state_dict:
        return
    # Validate every name up front so an over-long name fails fast instead of
    # producing a file the loader cannot read.
    too_long = [key for key in state_dict if len(key) > MAX_TENSOR_NAME_LENGTH]
    if too_long:
        details = ', '.join(f'{key!r} ({len(key)})' for key in too_long)
        raise ValueError(
            f'Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} '
            f'characters. Tensors exceeding the limit: {details}'
        )
    # Pad names in the log so the dtype/shape columns line up.
    max_name_len = max(map(len, state_dict))
    for key, tensor in tqdm(state_dict.items(), desc='Processing Tensors'):
        old_dtype = tensor.dtype
        data = validate_tensor_data(_tensor_to_numpy(tensor), key)
        if data is None:
            continue
        data_qtype = _select_qtype(old_dtype, data)
        try:
            data = quantize(data, data_qtype)
        except (AttributeError, QuantError) as e:
            tqdm.write(f'Quantization error ({e}), falling back to F16')
            data_qtype = GGMLQuantizationType.F16
            data = quantize(data, data_qtype)
        # Dimensions are printed in reverse order, innermost first.
        shape_str = '{' + ', '.join(map(str, reversed(data.shape))) + '}'
        tqdm.write(
            f'{key.ljust(max_name_len)} {old_dtype} --> {data_qtype.name}, '
            f'shape = {shape_str}'
        )
        writer.add_tensor(key, data, raw_dtype=data_qtype)
import os
# Interactive entry point: offer the .safetensors files in the current working
# directory, convert the chosen one, and write the GGUF file next to it.
# Sorted so the numbered menu is stable; os.listdir() order is arbitrary.
safetensors_files = sorted(
    file for file in os.listdir() if file.endswith('.safetensors')
)
if safetensors_files:
    print('Safetensors file(s) available. Select which one to convert:')
    for index, file_name in enumerate(safetensors_files, start=1):
        print(f'{index}. {file_name}')
    choice = input(f'Enter your choice (1 to {len(safetensors_files)}): ')
    # Guard only the parse, so errors raised later by the conversion are not
    # misreported as an invalid menu choice.
    try:
        choice_index = int(choice) - 1
    except ValueError:
        choice_index = -1
    # Explicit range check: plain list indexing would accept 0 or negative
    # numbers and silently pick a file counted from the end of the list.
    if not 0 <= choice_index < len(safetensors_files):
        print('Invalid choice. Please enter a valid number.')
    else:
        selected_file = safetensors_files[choice_index]
        print(f'Model file: {selected_file} is selected!')
        path = selected_file
        # Anything other than 'y' (including just pressing Enter) skips
        # naming, so the prompt advertises "no" as the default.
        ask = input('Assign a name for the model (y/N)? ')
        if ask.strip().lower() == 'y':
            given = input('Enter a model name: ')
        else:
            given = None
        writer, state_dict, _ = load_model(path, given)
        if not state_dict:
            print(f'No tensors found in {path}; nothing to convert.')
        else:
            try:
                writer.add_quantization_version(GGML_QUANT_VERSION)
                # The file type is guessed from the first tensor only, so a
                # model that mixes dtypes is labelled by whatever comes first.
                first_tensor_dtype = next(iter(state_dict.values())).dtype
                file_type = (
                    LlamaFileType.MOSTLY_BF16
                    if first_tensor_dtype == torch.bfloat16
                    else LlamaFileType.MOSTLY_F16
                )
                out_path = (
                    f'{os.path.splitext(path)[0]}-{file_type.name.lower()}.gguf'
                )
                writer.add_file_type(file_type)
                if os.path.isfile(out_path):
                    input(
                        'Output file exists. Press Enter to overwrite or Ctrl+C to abort.'
                    )
                handle_tensors(writer, state_dict)
                writer.write_header_to_file(path=out_path)
                writer.write_kv_data_to_file()
                writer.write_tensors_to_file(progress=True)
            finally:
                # Close the writer even if the conversion fails part-way.
                writer.close()
            print(f'Conversion completed: {out_path}')
else:
    print('No safetensors files are available in the current directory.')
    input('--- Press ENTER To Exit ---')